
package nsalt

import chisel3._
import chisel3.util._

import nsalt._
import nsalt.bus._
import nsalt.fetch.branch._

/** Mixin that supplies the address used as the initial program counter. */
trait ResetVectorProvided {
  // Declared as a Long: 0x80000000 is outside the positive range of Int.
  val resetVector: Long = 0x80000000L
}

// Derived from the embedded version of the IFU in NutShell. Original source:
// https://github.com/OSCPU/NutShell/blob/fd86beadfc47f52973270ce6109edebd2a30363b/src/main/scala/nutcore/frontend/IFU.scala#L259

/** Instruction fetch stage.
  *
  * Holds the program counter, chooses the next fetch address whenever the PC
  * advances, and forwards instruction-memory responses to the next stage.
  *
  * Next-PC priority, highest first: a valid redirect, a valid prediction from
  * the `Predict` module, then the sequential address `pcCurr + 4`.
  *
  * NOTE(review): nothing in this class assigns `io.imem.req.valid`,
  * `io.imem.req.bits` or `io.imem.res.ready`, and `io.ipf` is never read.
  * Confirm whether the request side of `imem` is supposed to be driven here.
  */
class Fetch extends Module with ResetVectorProvided with Config {

  val io = IO(new Bundle {
    // Instruction memory bus port.
    val imem = new BusUncached(userBits = 64, addrBits = VIRT_MEM_ADDR_LEN)
    // Fetched instruction together with its PC and predicted next PC.
    val out = Decoupled(new CtrlFlowPort())
    // Redirect request; when valid it overrides every other next-PC source.
    val redirect = Flipped(new RedirectPort())

    // ========================================================================
    // Two separate flush outputs:
    //   flushPipe - 4-bit flush vector for the pipeline (or a part of it).
    //   flushPred - flush intended for the branch predictor only.

    val flushPipe = Output(UInt(4.W))
    val flushPred = Output(Bool())

    // Page-fault signal from the TLB (not consumed inside this class).
    val ipf = Input(Bool())
  })

  // Program counter: the current fetch address, initialised to the reset
  // vector, and its sequential successor.
  // NOTE(review): the register is 32 bits wide while the bus is parameterised
  // with VIRT_MEM_ADDR_LEN - confirm the two widths agree.
  val pcCurr = RegInit(resetVector.U(32.W))
  val pcSucc = pcCurr + 4.U

  // The PC advances on a redirect or when a memory request handshake fires.
  val pcWillUpdate = io.redirect.valid || io.imem.req.fire

  // Branch predictor supplying a candidate next PC.
  val predict = Module(new Predict())

  val pcPred = predict.io.out.dest
  val pcNext = Mux(
    io.redirect.valid,
    io.redirect.dest,
    Mux(predict.io.out.valid, pcPred, pcSucc)
  )

  // The predictor is handed the *next* PC whenever a request fires, so that
  // its prediction is available for the following cycle. A redirect flushes it.
  predict.io.in.pc.valid := io.imem.req.fire
  predict.io.in.pc.bits  := pcNext
  predict.io.in.flush    := io.redirect.valid

  // Commit the selected next PC.
  when (pcWillUpdate) {
    pcCurr := pcNext
  }

  // Every bit of the flush vector is set while a redirect is valid, and clear
  // otherwise. The predictor-only flush is tied low.
  io.flushPipe := Fill(4, io.redirect.valid)
  io.flushPred := false.B

  // Default the whole output payload first; the specific fields assigned
  // below take precedence (last connection wins).
  io.out.bits := DontCare
  io.out.bits.instr := io.imem.res.bits.dataR

  // When the response carries a user field, split it into two
  // VIRT_MEM_ADDR_LEN-wide slices: upper slice -> pc, lower slice -> pcPred.
  // NOTE(review): presumably the request side packs {pc, next pc} into this
  // field; that packing is not visible in this file - verify.
  io.imem.res.bits.user.foreach { user =>
    io.out.bits.pc     := user(2 * VIRT_MEM_ADDR_LEN - 1, VIRT_MEM_ADDR_LEN)
    io.out.bits.pcPred := user(    VIRT_MEM_ADDR_LEN - 1,                 0)
  }

  // A response is passed on only when bit 0 of the flush vector is clear.
  io.out.valid := io.imem.res.valid && !io.flushPipe(0)
}
